from pathlib import Path

import matplotlib.pyplot as plt
from matplotlib.animation import FuncAnimation
import numpy as np

def make_plots(pde, ode, system_name, plot_kernel=True, plot_PDF=True, plot_loss=True, make_gif=True, x_lim=None, plot_intermediates=True):
    """Save diagnostic figures for a finished PDE/ODE run.

    Writes ``<system_name>_kernel.png`` when ``plot_kernel`` is set and
    ``<system_name>.png`` when ``plot_PDF`` is set. Figures are left open so
    an interactive caller can still ``plt.show()`` them.

    Args:
        pde: object exposing ``z_i``, ``rho0``, ``rho``, ``g0``, ``W`` and,
            when ``plot_intermediates`` is set, ``rho_history`` (2-D, one
            column per saved snapshot).
        ode: object exposing ``x``, ``x0`` and, when ``plot_intermediates`` is
            set, ``x_history``.
        system_name: filename prefix for the saved images.
        plot_kernel: draw the interaction kernel ``-W``.
        plot_PDF: draw the initial/final/target densities and the x shift.
        plot_loss: currently unused; kept for backward compatibility.
        make_gif: currently unused; kept for backward compatibility.
        x_lim: optional ``(xmin, xmax)`` applied to the density plot.
        plot_intermediates: overlay up to two intermediate snapshots on the
            density plot.
    """
    # The kernel is only evaluated when it is actually plotted.
    if plot_kernel:
        _plot_kernel(pde, system_name)
    if plot_PDF:
        _plot_pdf(pde, ode, system_name, x_lim, plot_intermediates)


def _plot_kernel(pde, system_name):
    """Plot ``-W(z, 0, 1.)`` on the cropped grid and save it as ``<system_name>_kernel.png``."""
    z_i = pde.z_i
    # Keep grid points below -z_i[0]. For an ascending grid that starts at a
    # negative value this gives a window roughly symmetric about 0
    # (assumption about the grid layout — TODO confirm).
    cropped_z_i = z_i[z_i < -z_i[0]]
    try:
        val = np.vectorize(pde.W)(cropped_z_i, 0, 1.)
    except Exception:
        # Fallback for kernels that cannot be evaluated element-wise: pass the
        # whole array at once. NOTE(review): the first two arguments are in
        # the opposite order from the vectorised call. Confirm this is
        # intended; it only matters if W is not symmetric in them.
        print("in exception")
        val = pde.W(0, cropped_z_i, 1.)
    plt.figure()
    plt.plot(cropped_z_i, -val)
    plt.title("Kernel")
    plt.xlabel("z")
    plt.ylabel("W(z)")
    plt.savefig(system_name + "_kernel.png")


def _plot_pdf(pde, ode, system_name, x_lim, plot_intermediates):
    """Plot initial, final and target densities plus the x0 -> x shift; save as ``<system_name>.png``."""
    z_i = pde.z_i
    rho0 = pde.rho0
    rho = pde.rho
    g0 = pde.g0
    x0 = ode.x0
    x = ode.x
    fontsize = 16
    blue = "#1f77b4"
    orange = "#ff7f0e"

    plt.figure()
    plt.plot(z_i, rho0, color=blue, alpha=0.5, linestyle="--", label=r"$\tilde{\rho}$")
    plt.plot(z_i, rho, color=blue, label=r"$\rho$")
    plt.plot(z_i, g0, color=orange, label=r"$\bar{\rho}$")
    if x_lim is not None:
        plt.xlim(x_lim)
    # Capture the y-range spanned by the densities so the vertical marker
    # lines below cover the plotted curves.
    ylim = plt.ylim()
    plt.plot([x0, x0], ylim, alpha=0.3, color='k', linestyle="-.")
    plt.plot([x, x], ylim, alpha=0.5, color='k', linestyle="-.", label="x")
    ax = plt.gca()
    ax.fill_between(z_i, rho0, alpha=0.15, color=blue)
    ax.fill_between(z_i, rho, alpha=0.3, color=blue)
    ax.fill_between(z_i, g0, alpha=0.3, color=orange)
    # Horizontal arrow from x0 to x, drawn at mid-height of the y-range.
    plt.arrow(x0, np.mean(ylim), x - x0, 0, length_includes_head=True, color="k", width=0.005, head_width=0.05, head_length=0.2)

    if plot_intermediates:
        rho_history = pde.rho_history
        x_history = ode.x_history
        # Show at most the first two saved snapshots. Fewer are drawn when
        # the histories are shorter, instead of raising IndexError.
        n_shown = min(2, len(x_history), rho_history.shape[1])
        for k in range(n_shown):
            plt.plot([x_history[k], x_history[k]], ylim, alpha=0.35 + 0.05 * k, color='k', linestyle="-.")
        for k in range(n_shown):
            plt.plot(z_i, rho_history[:, k], color=blue, alpha=0.5 + 0.05 * k, linestyle="--")
            ax.fill_between(z_i, rho_history[:, k], alpha=0.15 + 0.05 * k, color=blue)

    plt.legend(fontsize=fontsize)
    plt.ylabel("probability", fontsize=fontsize)
    plt.xlabel("z", fontsize=fontsize)
    ax.tick_params(axis='x', labelsize=fontsize)
    ax.tick_params(axis='y', labelsize=fontsize)
    plt.tight_layout()
    plt.savefig(system_name + ".png")

def plot_four_windows(ode, pde, title, filename, x_lim=(-2.5, 4.98), out_dir="plots"):
    """Save a row of density snapshots: initial state, saved intermediates, final state.

    Every panel shows the target density ``pde.g0`` (orange) and the density
    at that snapshot (blue), annotated with its time ``step * ode.dt``.
    Panels after the first also show the initial density faintly and dashed
    markers at ``x0`` and at that snapshot's ``x``. Up to four panels are
    drawn; fewer are drawn when fewer snapshots exist.

    Args:
        ode: object exposing ``x0``, ``x``, ``x_history``, ``saveat`` and ``dt``.
        pde: object exposing ``z_i``, ``rho0``, ``rho``, ``g0``,
            ``rho_history`` (2-D, one column per saved snapshot) and ``nT``.
        title: figure super-title.
        filename: output basename; the image is written to
            ``<out_dir>/<filename>.png``.
        x_lim: x-axis limits applied to every panel.
        out_dir: output directory, created if it does not already exist.
    """
    z_i = pde.z_i
    g0 = pde.g0
    x0 = ode.x0
    # Step index of every snapshot: 0, the saved steps, then the final step nT.
    timestamps = np.append(np.insert(ode.saveat, 0, 0), pde.nT)
    # Column k of rho_history (entry k of x_history) is the state at timestamps[k].
    rho_history = np.hstack([np.expand_dims(pde.rho0, 1), pde.rho_history, np.expand_dims(pde.rho, 1)])
    x_history = np.hstack([x0, ode.x_history, ode.x])
    n_panels = min(4, rho_history.shape[1])
    fontsize = 32
    titlesize = 36

    # squeeze=False keeps `axes` 2-D even for a single panel.
    fig, axes = plt.subplots(nrows=1, ncols=n_panels, figsize=(45 * n_panels / 4, 8), sharey=True, squeeze=False)
    axes = axes[0]
    axes[0].set_ylabel("probability", fontsize=fontsize)
    ylim = None
    for idx, ax in enumerate(axes):
        ax.plot(z_i, g0, color="#ff7f0e", label=r"$\bar{\rho}$")
        ax.fill_between(z_i, g0, alpha=0.3, color="#ff7f0e")
        ax.plot(z_i, rho_history[:, idx], color="#1f77b4", alpha=0.9, label=r"$\rho$")
        ax.fill_between(z_i, rho_history[:, idx], alpha=0.2, color="#1f77b4")
        ax.set_xlim(list(x_lim))
        ax.tick_params(axis='x', labelsize=fontsize)
        ax.tick_params(axis='y', labelsize=fontsize)
        ax.set_xlabel("z", fontsize=fontsize)
        ax.annotate("t=" + str(round(timestamps[idx] * ode.dt, ndigits=1)), (0.05, 0.85), xycoords="axes fraction", fontsize=fontsize)

        if idx == 0:
            # The y-axis is shared, so the first panel's range is reused as
            # the height of every vertical marker line.
            ylim = ax.get_ylim()
            ax.plot([x0, x0], ylim, color='k', linestyle="--", label="x", lw=3)
        else:
            ax.plot(z_i, rho_history[:, 0], color="#1f77b4", alpha=0.1, label=r"$\rho_0$")
            ax.fill_between(z_i, rho_history[:, 0], alpha=0.1, color="#1f77b4")
            ax.plot([x0, x0], ylim, alpha=0.5, color='k', linestyle="--", lw=2)
            ax.plot([x_history[idx], x_history[idx]], ylim, color='k', linestyle="--", label="x", lw=3)
        ax.legend(loc=1, fontsize=fontsize)
    fig.suptitle(title, fontsize=titlesize)

    fig.tight_layout()
    out_path = Path(out_dir)
    out_path.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path / (filename + ".png"))